#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import pickle

import keras.utils
import numpy as np
import six
from gensim.models.word2vec import Word2Vec
from keras.preprocessing.text import Tokenizer
from sklearn.model_selection import train_test_split

from const import DEFAULT_CONFIG, EMBEDDING_DIM, CPU_COUNT, MODEL_PATH
from util import load_word2vec, load_tokenizer, tokenize
from util.db import Mysql
from util.jb import fc

__all__ = ['load_data', 'word2vec_train', 'data_set']

_SQL = 'select id, content, tag from a10jqka_article order by id;'
_SQL_WITH_LAST_ID = 'select id, content, tag from a10jqka_article where id > %d order by id;'


def load_data(last_id=None, **kwargs):
    """Fetch articles from MySQL and tokenize their content.

    Rows whose content is shorter than 20 characters are skipped.
    A tag of -1 is remapped to class 2 so labels form the set {0, 1, 2}.

    :param last_id: when truthy, only rows with id > last_id are fetched.
    :param kwargs: may carry a ``'db'`` dict of connection overrides.
    :return: (texts, labels, next_id) — next_id is the id of the last row
             seen, or ``last_id`` unchanged when no rows matched.
    """
    # Copy so the shared DEFAULT_CONFIG dict is never mutated in place.
    db_conf = dict(DEFAULT_CONFIG.get('db', {}))
    overrides = kwargs.get('db')
    if overrides:  # original crashed with update(None) when 'db' was absent
        db_conf.update(overrides)
    db = Mysql(**db_conf)
    db.init()
    selector = db.selector()
    try:
        if last_id:
            data = selector.send(_SQL_WITH_LAST_ID % (last_id,))
        else:
            data = selector.send(_SQL)
    finally:
        # Always release the selector, even if the query raises.
        db.close_selector(selector)
    texts = []
    labels = []
    for _id, content, tag in data:
        if len(content) < 20:  # too short to carry a useful signal
            continue
        texts.append(fc(content))
        labels.append(tag if int(tag) != -1 else 2)
    # Guard against an empty result set (original raised IndexError here).
    next_id = data[-1][0] if data else last_id
    return texts, labels, next_id


def word2vec_train(texts, conf, update=True, **kwargs):
    """Train (or incrementally update) a Word2Vec model, then build the
    word index, word-vector dict, and padded sequence matrix.

    :param texts: iterable of whitespace-joined token strings.
    :param conf: dict with 'exposures', 'window_size', 'iterations', 'max_len'.
    :param update: when True and a saved model exists, resume training it.
    :param kwargs: 'word2vec' (model file stem), 'save' (persist model),
                   plus anything forwarded to :func:`create_tokenize`.
    :return: (word_index, word_vectors, combined)
    """
    w2v_path = os.path.join(MODEL_PATH, '%s.pkl' % (kwargs.get('word2vec'),))
    resume = update and os.path.exists(w2v_path)
    if resume:
        model = load_word2vec(w2v_path)
    else:
        model = Word2Vec(size=EMBEDDING_DIM,
                         min_count=conf.get('exposures'),
                         window=conf.get('window_size'),
                         workers=CPU_COUNT,
                         iter=conf.get('iterations'))
    sentences = [_t.split(' ') for _t in texts]
    # update=True is required when resuming: build_vocab without it would
    # reset the loaded model's vocabulary and trained weights.
    model.build_vocab(sentences, update=resume)
    # Bug fix: train on the tokenized sentences (list of token lists),
    # the same corpus build_vocab received — not the raw joined strings.
    model.train(sentences, epochs=model.epochs, total_examples=model.corpus_count)
    if kwargs.get('save'):
        with open(w2v_path, 'wb') as _fp:
            model.save(_fp, pickle_protocol=2)
    word_index, word_vectors, combined = create_tokenize(_model=model, texts=texts, max_len=conf.get('max_len'),
                                                         **kwargs)
    return word_index, word_vectors, combined


def data_set(word_index, word_vectors, combined, y):
    """Build the embedding matrix and an 80/20 train/test split.

    :param word_index: dict mapping word -> integer index (1-based).
    :param word_vectors: dict mapping word -> embedding vector.
    :param combined: sequence matrix aligned row-for-row with ``y``.
    :param y: class labels in {0, 1, 2}.
    :return: (n_symbols, embedding_weights, x_train, y_train, x_test, y_test)
             with the label arrays one-hot encoded over 3 classes.
    """
    n_symbols = 1 + len(word_index)  # index 0 is reserved (all-zero row)
    embedding_weights = np.zeros((n_symbols, EMBEDDING_DIM))
    for word, idx in word_index.items():
        vector = word_vectors.get(word)
        if vector is not None:  # words pruned from word2vec keep the zero row
            embedding_weights[idx, :] = vector
    x_train, x_test, y_train, y_test = train_test_split(combined, y, test_size=0.2)
    return (n_symbols,
            embedding_weights,
            x_train,
            keras.utils.to_categorical(y_train, num_classes=3),
            x_test,
            keras.utils.to_categorical(y_test, num_classes=3))


def create_tokenize(_model=None, texts=None, max_len=None, update=True, **kwargs):
    """Fit (or reload) a keras Tokenizer and collect per-word vectors.

    :param _model: trained Word2Vec model vectors are pulled from.
    :param texts: iterable of whitespace-joined token strings.
    :param max_len: padding length forwarded to ``tokenize``.
    :param update: when True and a saved tokenizer exists, reuse it.
    :param kwargs: 'tokenize' (file stem), 'save' (persist the tokenizer).
    :raises ValueError: if ``texts`` or ``_model`` is missing.
    :return: (word_index, word_vectors, combined)
    """
    t_path = os.path.join(MODEL_PATH, '%s.pkl' % (kwargs.get('tokenize'),))
    if texts is None or _model is None:
        # ValueError is a subclass of Exception, so existing handlers still match.
        raise ValueError('No data provided...')
    if update and os.path.exists(t_path):
        tokenizer = load_tokenizer(t_path)
    else:
        tokenizer = Tokenizer()
    tokenizer.fit_on_texts(texts)
    word_index = tokenizer.word_index
    word_vectors = {}
    for word in word_index:  # plain dict iteration works on py2 and py3
        try:
            word_vectors[word] = _model[word]
        except KeyError:
            pass  # word was pruned by min_count; no vector available
    if kwargs.get('save'):
        # Bug fix: close the file handle (original leaked an open file).
        with open(t_path, 'wb') as _fp:
            pickle.dump(tokenizer, _fp, protocol=2)
    combined = tokenize(tokenizer, texts, max_len)
    return word_index, word_vectors, combined
